{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# GPT-Neo inference with the Hugging Face Transformers library\n",
        "This notebook is a companion to Chapter 3 of the book \"Domain Specific LLMs in Action\" by Guglielmo Iozzia, [Manning Publications](https://www.manning.com/), 2024.  \n",
        "It introduces inference (text generation) with the [GPT-Neo model](https://github.com/EleutherAI/gpt-neo) using the Hugging Face [Transformers library](https://github.com/huggingface/transformers). It can be executed in the Colab free tier with hardware acceleration (GPU).  \n",
        "More details about the code can be found in the book's chapter."
      ],
      "metadata": {
        "id": "pCU24Cr8XsWb"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Install the only requirement missing from the Colab VM: Hugging Face Accelerate."
      ],
      "metadata": {
        "id": "QZPtUdQxYoJm"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5cgXtRrrfjzF"
      },
      "outputs": [],
      "source": [
        "!pip install accelerate"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Download the GPT-Neo 2.7B model and its tokenizer from the Hugging Face Hub. The model is loaded in full precision (float32), and `device_map=\"auto\"` lets Accelerate place its layers on the GPU when they fit."
      ],
      "metadata": {
        "id": "YN89RT34YyTl"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from transformers import GPTNeoForCausalLM, GPT2Tokenizer\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "model_id = \"EleutherAI/gpt-neo-2.7B\"\n",
        "tokenizer = GPT2Tokenizer.from_pretrained(model_id)\n",
        "# device_map=\"auto\" lets Accelerate place the weights, so no explicit model.to(device) is needed;\n",
        "# moving a model that Accelerate has offloaded to CPU or disk can fail.\n",
        "model = GPTNeoForCausalLM.from_pretrained(model_id, device_map=\"auto\")"
      ],
      "metadata": {
        "id": "NA52Ly1dfy05"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Check where the model layers have been placed: entirely in GPU memory, or partly offloaded to RAM and/or disk."
      ],
      "metadata": {
        "id": "ljgwRcLZZCn3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model.hf_device_map"
      ],
      "metadata": {
        "id": "zm6qY99tJw-L"
      },
      "execution_count": null,
      "outputs": []
    },
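    {
      "cell_type": "markdown",
      "source": [
        "A back-of-the-envelope estimate of the memory needed by the weights alone can help when reading the device map. This is a rough sketch, not a measurement: it assumes the nominal 2.7 billion parameters suggested by the model name and the T4's nominal 16 GB, and it ignores activations, the KV cache and framework overhead."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "n_params = 2.7e9  # nominal count taken from the model name (assumption)\n",
        "bytes_per_param = {'float32': 4, 'float16': 2}\n",
        "t4_memory_gb = 16  # nominal T4 memory; the usable amount is somewhat lower\n",
        "\n",
        "for dtype, size in bytes_per_param.items():\n",
        "    weights_gb = n_params * size / 1e9  # bytes converted to gigabytes\n",
        "    print(f'{dtype}: ~{weights_gb:.1f} GB of weights (T4: {t4_memory_gb} GB)')"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },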
    {
      "cell_type": "markdown",
      "source": [
        "Perform standard inference (text completion)."
      ],
      "metadata": {
        "id": "-wcXcAoYZU4J"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "prompt = \"The story so far: in the beginning, the universe was created.\"\n",
        "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(device)\n",
        "\n",
        "generated_ids = model.generate(input_ids,\n",
        "                               do_sample=True,\n",
        "                               temperature=0.9,\n",
        "                               max_length=200,\n",
        "                               pad_token_id=50256)\n",
        "generated_text = tokenizer.decode(generated_ids[0])\n",
        "print(generated_text)"
      ],
      "metadata": {
        "id": "1ZjuJkgMf3-e"
      },
      "execution_count": null,
      "outputs": []
    },
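    {
      "cell_type": "markdown",
      "source": [
        "A minimal sketch of what the `temperature` argument used above does, assuming the usual formulation in which the logits are divided by the temperature before the softmax. The logits and the helper name `softmax_with_temperature` are made up for illustration and are not part of Transformers. Lower temperatures concentrate probability on the top-scoring token, while higher temperatures flatten the distribution."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "def softmax_with_temperature(logits, temperature=1.0):\n",
        "    # Hypothetical helper: scale the logits, then apply a numerically stable softmax.\n",
        "    scaled = np.asarray(logits, dtype=np.float64) / temperature\n",
        "    scaled -= scaled.max()\n",
        "    probs = np.exp(scaled)\n",
        "    return probs / probs.sum()\n",
        "\n",
        "logits = [2.0, 1.0, 0.1]  # made-up scores for three candidate tokens\n",
        "for t in (0.5, 0.9, 1.5):\n",
        "    print(t, np.round(softmax_with_temperature(logits, t), 3))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },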
    {
      "cell_type": "markdown",
      "source": [
        "Perform few-shot text classification: the model picks up the task from a few labeled examples given in the prompt, without any fine-tuning."
      ],
      "metadata": {
        "id": "xImTdNsKZf0n"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "prompt = \"\"\"\n",
        "Sentence: This movie is very nice.\n",
        "Sentiment: positive\n",
        "\n",
        "#####\n",
        "\n",
        "Sentence: I hated this movie, it sucks.\n",
        "Sentiment: negative\n",
        "\n",
        "#####\n",
        "\n",
        "Sentence: This movie was actually pretty funny.\n",
        "Sentiment: positive\n",
        "\n",
        "#####\n",
        "\n",
        "Sentence: This movie could have been better.\n",
        "Sentiment:\"\"\"\n",
        "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(device)\n",
        "\n",
        "generated_ids = model.generate(input_ids,\n",
        "                               do_sample=True,\n",
        "                               temperature=0.9,\n",
        "                               max_length=200,\n",
        "                               pad_token_id=50256)\n",
        "generated_text = tokenizer.decode(generated_ids[0])\n",
        "print(generated_text)"
      ],
      "metadata": {
        "id": "eNwj1bPlu0KZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Generate Python code from a natural-language instruction."
      ],
      "metadata": {
        "id": "Vws8q09Nab5y"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "prompt = \"\"\"Instruction: Generate a Python function that lets you reverse a list of integers.\n",
        "\n",
        "Answer: \"\"\"\n",
        "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(device)\n",
        "\n",
        "generated_ids = model.generate(input_ids,\n",
        "                               do_sample=True,\n",
        "                               temperature=0.9,\n",
        "                               max_length=200,\n",
        "                               pad_token_id=50256\n",
        "                               )\n",
        "generated_text = tokenizer.decode(generated_ids[0])\n",
        "print(generated_text)"
      ],
      "metadata": {
        "id": "xjcA7JkIvibe"
      },
      "execution_count": null,
      "outputs": []
    },
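    {
      "cell_type": "markdown",
      "source": [
        "Perform batch text completion. The prompts are left-padded so that generation continues from the last real token of each prompt."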
      ],
      "metadata": {
        "id": "1bEY9T56aibv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "texts = [\"Once there was a man \", \"The weather today will be \", \"A great soccer player must \"]\n",
        "\n",
        "tokenizer.padding_side = \"left\"\n",
        "tokenizer.pad_token = tokenizer.eos_token\n",
        "encoding = tokenizer(texts, padding=True, return_tensors='pt').to(device)\n",
        "with torch.no_grad():\n",
        "    generated_ids = model.generate(**encoding,\n",
        "                                   do_sample=True,\n",
        "                                   temperature=0.9,\n",
        "                                   max_length=50,\n",
        "                                   pad_token_id=50256)\n",
        "generated_texts = tokenizer.batch_decode(\n",
        "    generated_ids, skip_special_tokens=True)\n",
        "\n",
        "for text in generated_texts:\n",
        "  print(\"---------\")\n",
        "  print(text)"
      ],
      "metadata": {
        "id": "VYLTgEhjwkes"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Benchmark text completion with and without the KV cache."
      ],
      "metadata": {
        "id": "refKUvsCawPP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import time\n",
        "import numpy as np\n",
        "\n",
        "prompt = \"The story so far: in the beginning, the universe was created.\"\n",
        "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(device)\n",
        "\n",
        "for use_cache in (True, False):\n",
        "  times = []\n",
        "  for _ in range(20):\n",
        "    start = time.perf_counter()\n",
        "    generated_ids = model.generate(input_ids,\n",
        "                                  do_sample=True,\n",
        "                                  temperature=0.9,\n",
        "                                  max_length=200,\n",
        "                                  pad_token_id=50256,\n",
        "                                  use_cache=use_cache)\n",
        "    if device.type == \"cuda\":\n",
        "      torch.cuda.synchronize()  # wait for queued GPU work before stopping the timer\n",
        "    times.append(time.perf_counter() - start)\n",
        "  print(f\"{'Using' if use_cache else 'No'} KV cache: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds\")"
      ],
      "metadata": {
        "id": "CKVJVqwSTXgg"
      },
      "execution_count": null,
      "outputs": []
    },
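    {
      "cell_type": "markdown",
      "source": [
        "To see why the KV cache saves work, the toy example below runs single-head attention token by token in NumPy, once recomputing keys and values for the whole prefix at every step and once appending only the newest token's key and value to a cache. It is an illustrative sketch with random weights and made-up embeddings, not GPT-Neo's actual implementation, and all names (`attend`, `K_cache`, `V_cache`) are hypothetical. Both paths should produce the same outputs up to floating-point error, while the cached path performs one key/value projection per step instead of one per prefix token."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "d = 8      # toy hidden size\n",
        "steps = 5  # number of tokens processed one at a time\n",
        "Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))\n",
        "tokens = rng.normal(size=(steps, d))  # made-up token embeddings\n",
        "\n",
        "def attend(q, K, V):\n",
        "    # Scaled dot-product attention for one query over all keys/values seen so far.\n",
        "    scores = K @ q / np.sqrt(d)\n",
        "    weights = np.exp(scores - scores.max())\n",
        "    weights /= weights.sum()\n",
        "    return weights @ V\n",
        "\n",
        "# Without a cache: recompute keys and values for the whole prefix at every step.\n",
        "no_cache = []\n",
        "for t in range(steps):\n",
        "    prefix = tokens[: t + 1]\n",
        "    K, V = prefix @ Wk, prefix @ Wv\n",
        "    no_cache.append(attend(tokens[t] @ Wq, K, V))\n",
        "\n",
        "# With a cache: project only the newest token and append it to the stored keys/values.\n",
        "K_cache, V_cache = np.empty((0, d)), np.empty((0, d))\n",
        "with_cache = []\n",
        "for t in range(steps):\n",
        "    K_cache = np.vstack([K_cache, tokens[t] @ Wk])\n",
        "    V_cache = np.vstack([V_cache, tokens[t] @ Wv])\n",
        "    with_cache.append(attend(tokens[t] @ Wq, K_cache, V_cache))\n",
        "\n",
        "# Compare the per-step attention outputs of the two paths.\n",
        "print(np.allclose(no_cache, with_cache))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },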
    {
      "cell_type": "markdown",
      "source": [
        "Benchmark the model's total generation time. The first run is treated as warm-up and excluded from the statistics."
      ],
      "metadata": {
        "id": "sbDehxEOfA4m"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import time\n",
        "import numpy as np\n",
        "\n",
        "prompt = \"The story so far: in the beginning, the universe was created.\"\n",
        "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(device)\n",
        "\n",
        "max_length = 300\n",
        "times = []\n",
        "inference_runs = 21\n",
        "for _ in range(inference_runs):\n",
        "  start = time.perf_counter()\n",
        "  generated_ids = model.generate(input_ids,\n",
        "                                do_sample=True,\n",
        "                                temperature=0.9,\n",
        "                                max_length=max_length,\n",
        "                                pad_token_id=50256,\n",
        "                                )\n",
        "  if device.type == \"cuda\":\n",
        "    torch.cuda.synchronize()  # wait for queued GPU work before stopping the timer\n",
        "  times.append(time.perf_counter() - start)\n",
        "# Skip the first run (times[0]), which serves as warm-up.\n",
        "print(f\"Average Total Generation time: {round(np.mean(times[1:]), 3)} +- {round(np.std(times[1:]), 3)} seconds\")"
      ],
      "metadata": {
        "id": "wfuH8Djp0QgZ"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}